{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import keras\n",
    "import wandb\n",
    "from wandb.keras import WandbCallback\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Flatten, Dense, Concatenate, Dot, Lambda, Input\n",
    "from keras.datasets import mnist\n",
    "from keras.optimizers import Adam\n",
    "import matplotlib.pyplot as plt\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST train/test splits, then convert the images to float32\n",
    "# and divide every pixel value by 255.\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make pairs\n",
    "def make_pairs(x, y, seed=None):\n",
    "    \"\"\"Build matching and non-matching sample pairs for siamese training.\n",
    "\n",
    "    For every sample in `x`, two pairs are emitted back to back: first the\n",
    "    sample with a randomly chosen sample of the same class (label 1), then\n",
    "    the sample with a randomly chosen sample of a different class (label 0).\n",
    "\n",
    "    Args:\n",
    "        x: Indexable collection of samples; x[i] belongs to class y[i].\n",
    "        y: Non-negative integer class labels, one per sample.\n",
    "        seed: Optional seed for a private random generator so the pairs are\n",
    "            reproducible. When None (default) the global `random` module\n",
    "            is used.\n",
    "\n",
    "    Returns:\n",
    "        (pairs, labels): `pairs` is an array holding 2 * len(x) pairs\n",
    "        [x1, x2]; `labels` is the array of the corresponding 1/0 flags.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `x` and `y` differ in length, or if fewer than two\n",
    "            classes have samples (no non-matching pair can be formed).\n",
    "    \"\"\"\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "    y = np.asarray(y)\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError('x and y must have the same length')\n",
    "\n",
    "    num_classes = int(y.max()) + 1 if len(y) else 0\n",
    "    digit_indices = [np.where(y == i)[0] for i in range(num_classes)]\n",
    "\n",
    "    # A non-matching partner needs a second class that actually has samples;\n",
    "    # without one the rejection loop below could never terminate.\n",
    "    populated = [i for i in range(num_classes) if len(digit_indices[i])]\n",
    "    if len(populated) < 2:\n",
    "        raise ValueError('make_pairs needs samples from at least two classes')\n",
    "\n",
    "    pairs = []\n",
    "    labels = []\n",
    "\n",
    "    for idx1 in range(len(x)):\n",
    "        x1 = x[idx1]\n",
    "        label1 = int(y[idx1])\n",
    "\n",
    "        # add a matching example\n",
    "        idx2 = rng.choice(digit_indices[label1])\n",
    "        pairs.append([x1, x[idx2]])\n",
    "        labels.append(1)\n",
    "\n",
    "        # add a not matching example: redraw until the label differs from\n",
    "        # label1 and belongs to a class that has at least one sample\n",
    "        label2 = rng.randint(0, num_classes - 1)\n",
    "        while label2 == label1 or len(digit_indices[label2]) == 0:\n",
    "            label2 = rng.randint(0, num_classes - 1)\n",
    "\n",
    "        idx2 = rng.choice(digit_indices[label2])\n",
    "        pairs.append([x1, x[idx2]])\n",
    "        labels.append(0)\n",
    "\n",
    "    return np.array(pairs), np.array(labels)\n",
    "\n",
    "# Build the pair datasets from the train and test splits; each call returns\n",
    "# a (pairs, labels) tuple produced by make_pairs.\n",
    "pairs_train, labels_train = make_pairs(x_train, y_train)\n",
    "pairs_test, labels_test = make_pairs(x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADHlJREFUeJzt3W2MXPV5hvHrqTGm5UXCoTjGWMFNUFQXKSTdOmmC0lQ0EaGRDIpC46rUkWidVlA1VT4U0YrSbzRqEkWkSusEK06UkLRKEHwgaahbFaEixOIS3gyFIiOwjJfUaTFJMX55+mGP6WLvnl3Pnpkzy3P9pNXOnOfszq2Rb5+ZObPzj8xEUj0/03cASf2w/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9UlOWXijpllDd2aqzI0zh9lDcplfIKP+HVPBgL2XdR5Y+Iy4AvAMuAr2TmzW37n8bpvDsuXcxNSmpxf+5Y8L4DP+yPiGXA3wAfBtYDmyJi/aC/T9JoLeY5/wbg6cx8JjNfBb4FbOwmlqRhW0z51wDPzbj+fLPtdSJiS0RMRsTkIQ4u4uYkdWnor/Zn5tbMnMjMieWsGPbNSVqgxZR/D7B2xvXzm22SloDFlP8B4MKIWBcRpwIfB+7sJpakYRv4VF9mHo6I64B/ZPpU37bMfKyzZJKGalHn+TPzLuCujrJIGiHf3isVZfmloiy/VJTll4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUSNdoltvPKese0vr/GPfu2/O2SfOmmr92SN5tHU+MfnbrfNzNz7ROq/OI79UlOWXirL8UlGWXyrK8ktFWX6pKMsvFbWo8/wRsRs4ABwBDmfmRBehND7il3+pdb7ub59unW86c8+cs4M5UKTXHM1Y3C8oros3+fx6Zv6og98jaYR82C8VtdjyJ/CDiHgwIrZ0EUjSaCz2Yf8lmbknIs4F7o6IJzLznpk7NP8pbAE4jZ9b5M1J6sqijvyZuaf5PgXcDmyYZZ+tmTmRmRPLWbGYm5PUoYHLHxGnR8SZxy4DHwIe7SqYpOFazMP+VcDtEXHs93wzM7/fSSpJQzdw+TPzGeAdHWbRGHry2p9tnd9x3r0jSqKueapPKsryS0VZfqkoyy8VZfmloiy/VJTll4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1RUF6v0agl78Q9/tXX+lV/bOqIkGjWP/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9U1Lzn+SNiG/ARYCozL2q2rQS+DVwA7AauyswfDy+mhuV/z43W+SWnvTKiJCf6830bWudrrtnXOj/SZZg3oIUc+b8KXHbctuuBHZl5IbCjuS5pCZm3/Jl5D7D/uM0bge3N5e3AFR3nkjRkgz7nX5WZe5vLLwCrOsojaUQW/YJfZiaQc80jYktETEbE5CEOLvbmJHVk0PLvi4jVAM33qbl2zMytmTmRmRPLWTHgzUnq2qDlvxPY3FzeDNzRTRxJozJv+SPiNuA+4O0R8XxEXAPcDHwwIp4CfqO5LmkJmfc8f2ZummN0acdZNATLznlT6/yVVYdHlOREPz16qHX+/Wd/sXV+3n893mWccnyHn1SU5ZeKsvxSUZZfKsryS0VZfqkoP7r7DaDtdN4TN76t9Wef2PjFruMs2Psnf691ft6VnsobJo/8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5/nfAA6/fe2cs10f7e88/nzO/93nWudHR5SjKo/8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkoyy8VZfmloiy/VNS8f88fEduAjwBTmXlRs+0m4PeBF5vdbsjMu4YVUu1OeXLuv4tf/w9/
1Pqzj3/slq7jaIlYyJH/q8Bls2z/fGZe3HxZfGmJmbf8mXkPsH8EWSSN0GKe818XEQ9HxLaIOLuzRJJGYtDyfwl4K3AxsBf47Fw7RsSWiJiMiMlDHBzw5iR1baDyZ+a+zDySmUeBLwMbWvbdmpkTmTmxnBWD5pTUsYHKHxGrZ1y9Eni0mziSRmUhp/puAz4AnBMRzwN/AXwgIi4GEtgNfHKIGSUNwbzlz8xNs2y+dQhZNKCja9885+wzl39zhEm0lPgOP6koyy8VZfmloiy/VJTll4qy/FJRLtG9BLznh4da539w9t/NOVu5rN93Vf7m78z9FpBlL//7CJPoeB75paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkoz/OPgXzvO1rn//yXp7XOb7jloS7jnJSt//221vnyF3865+xoZtdxdBI88ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUZ7nHwPrb3msdf5Xb75vRElOdOPUr7TOH/yTd7XOlz26s8s46pBHfqkoyy8VZfmloiy/VJTll4qy/FJRll8qat7z/BGxFvgasApIYGtmfiEiVgLfBi4AdgNXZeaPhxdVfdi5f23r/JT/Odg69y/2x9dCjvyHgU9n5nrgPcC1EbEeuB7YkZkXAjua65KWiHnLn5l7M3Nnc/kAsAtYA2wEtje7bQeuGFZISd07qef8EXEB8E7gfmBVZu5tRi8w/bRA0hKx4PJHxBnAd4BPZeZLM2eZmczx9C4itkTEZERMHqL9+aGk0VlQ+SNiOdPF/0ZmfrfZvC8iVjfz1cDUbD+bmVszcyIzJ5bT76KRkv7fvOWPiABuBXZl5udmjO4ENjeXNwN3dB9P0rAs5E963wdcDTwSEcc+I/oG4Gbg7yPiGuBZ4KrhRFz69lz/3tb5jWd/cURJTt7UgTNa52teebV1fqTLMOrUvOXPzHuBmGN8abdxJI2K7/CTirL8UlGWXyrK8ktFWX6pKMsvFeVHd4/Aef/6k9b5zqvXtc4nVjzdZZyTcujBs1vnR3b924iSqGse+aWiLL9UlOWXirL8UlGWXyrK8ktFWX6pKM/zj8BTW9rv5t866/F5fsOp3YWRGh75paIsv1SU5ZeKsvxSUZZfKsryS0VZfqmomF5pazTOipX57vDTvqVhuT938FLun+uj9l/HI79UlOWXirL8UlGWXyrK8ktFWX6pKMsvFTVv+SNibUT8S0Q8HhGPRcQfN9tviog9EfFQ83X58ONK6spCPszjMPDpzNwZEWcCD0bE3c3s85n518OLJ2lY5i1/Zu4F9jaXD0TELmDNsINJGq6Tes4fERcA7wTubzZdFxEPR8S2iJh1XaeI2BIRkxExeYiDiworqTsLLn9EnAF8B/hUZr4EfAl4K3Ax048MPjvbz2Xm1sycyMyJ5azoILKkLiyo/BGxnOnifyMzvwuQmfsy80hmHgW+DGwYXkxJXVvIq/0B3ArsyszPzdi+esZuVwKPdh9P0rAs5NX+9wFXA49ExEPNthuATRFxMZDAbuCTQ0koaSgW8mr/vcBsfx98V/dxJI2K7/CTirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkoyy8VNdIluiPiReDZGZvOAX40sgAnZ1yzjWsuMNugusz2lsz8+YXsONLyn3DjEZOZOdFbgBbjmm1cc4HZBtVXNh/2S0VZfqmovsu/tefbbzOu2cY1F5htUL1k6/U5v6T+9H3kl9STXsofEZdFxJMR8XREXN9HhrlExO6IeKRZeXiy5yzbImIqIh6dsW1lRNwdEU8132ddJq2nbGOxcnPLytK93nfjtuL1yB/2R8Qy4D+ADwLPAw8AmzLz8ZEGmUNE7AYmMrP3c8IR8X7gZeBrmXlRs+0zwP7MvLn5j/PszPzTMcl2E/By3ys3NwvKrJ65sjRwBfAJerzvWnJdRQ/3Wx9H/g3A05n5TGa+CnwL2NhDjrGXmfcA+4/bvBHY3lze
zvQ/npGbI9tYyMy9mbmzuXwAOLaydK/3XUuuXvRR/jXAczOuP894LfmdwA8i4sGI2NJ3mFmsapZNB3gBWNVnmFnMu3LzKB23svTY3HeDrHjdNV/wO9Elmfku4MPAtc3D27GU08/Zxul0zYJWbh6VWVaWfk2f992gK153rY/y7wHWzrh+frNtLGTmnub7FHA747f68L5ji6Q236d6zvOacVq5ebaVpRmD+26cVrzuo/wPABdGxLqIOBX4OHBnDzlOEBGnNy/EEBGnAx9i/FYfvhPY3FzeDNzRY5bXGZeVm+daWZqe77uxW/E6M0f+BVzO9Cv+/wn8WR8Z5sj1C8APm6/H+s4G3Mb0w8BDTL82cg3wJmAH8BTwT8DKMcr2deAR4GGmi7a6p2yXMP2Q/mHgoebr8r7vu5ZcvdxvvsNPKsoX/KSiLL9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFfV/V+LUNDabNTIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Take a peek at the data: display the second image of one pair and print\n",
    "# the label of that same pair (1 = matching digits, 0 = different digits).\n",
    "pair_idx = 400\n",
    "plt.imshow(pairs_train[pair_idx, 1])\n",
    "print(labels_train[pair_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights not shared: each input image gets its own Flatten + Dense(128) tower.\n",
    "seq1 = Sequential([\n",
    "    Flatten(input_shape=(28,28)),\n",
    "    Dense(128, activation='relu'),\n",
    "])\n",
    "seq2 = Sequential([\n",
    "    Flatten(input_shape=(28,28)),\n",
    "    Dense(128, activation='relu'),\n",
    "])\n",
    "\n",
    "# Concatenate both tower outputs and feed them to a single sigmoid unit.\n",
    "merge_layer = Concatenate()([seq1.output, seq2.output])\n",
    "dense_layer = Dense(1, activation=\"sigmoid\")(merge_layer)\n",
    "model = Model(inputs=[seq1.input, seq2.input], outputs=dense_layer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "flatten_1_input (InputLayer)    (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2_input (InputLayer)    (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 784)          0           flatten_1_input[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)             (None, 784)          0           flatten_2_input[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 128)          100480      flatten_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 128)          100480      flatten_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 256)          0           dense_1[0][0]                    \n",
      "                                                                 dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 1)            257         concatenate_1[0][0]              \n",
      "==================================================================================================\n",
      "Total params: 201,217\n",
      "Trainable params: 201,217\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Binary cross-entropy loss, Adam optimizer, accuracy as the reported metric.\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W&B Run: https://app.wandb.ai/l2k2/siamese/runs/6v1gb888\n",
      "Call `%%wandb` in the cell containing your training loop to display live results.\n",
      "Epoch 1/10\n",
      "120000/120000 [==============================] - 38s 317us/step - loss: 0.6947 - acc: 0.5010\n",
      "Epoch 2/10\n",
      "120000/120000 [==============================] - 38s 314us/step - loss: 0.6931 - acc: 0.5054\n",
      "Epoch 3/10\n",
      "120000/120000 [==============================] - 38s 317us/step - loss: 0.6930 - acc: 0.5072\n",
      "Epoch 4/10\n",
      "120000/120000 [==============================] - 35s 290us/step - loss: 0.6928 - acc: 0.5087\n",
      "Epoch 5/10\n",
      "120000/120000 [==============================] - 35s 289us/step - loss: 0.6924 - acc: 0.5082\n",
      "Epoch 6/10\n",
      "120000/120000 [==============================] - 35s 289us/step - loss: 0.6920 - acc: 0.5130\n",
      "Epoch 7/10\n",
      "120000/120000 [==============================] - 37s 309us/step - loss: 0.6914 - acc: 0.5140\n",
      "Epoch 8/10\n",
      "120000/120000 [==============================] - 38s 314us/step - loss: 0.6906 - acc: 0.5168\n",
      "Epoch 9/10\n",
      "120000/120000 [==============================] - 40s 331us/step - loss: 0.6899 - acc: 0.5220\n",
      "Epoch 10/10\n",
      "120000/120000 [==============================] - 35s 295us/step - loss: 0.6890 - acc: 0.5242\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5a88c39320>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.init(project=\"siamese\")\n",
    "# Train on the pair images (first and second element of every pair) and\n",
    "# score the held-out test pairs after each epoch, so the logged metrics are\n",
    "# not limited to training accuracy.\n",
    "model.fit(\n",
    "    [pairs_train[:,0], pairs_train[:,1]], labels_train,\n",
    "    validation_data=([pairs_test[:,0], pairs_test[:,1]], labels_test),\n",
    "    batch_size=16, epochs=10,\n",
    "    callbacks=[WandbCallback()],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights shared: one Flatten + Dense(128) encoder is applied to both inputs.\n",
    "# Its input tensor is called `encoder_input` so that the built-in `input()`\n",
    "# is not shadowed in the kernel namespace.\n",
    "encoder_input = Input((28,28))\n",
    "x = Flatten()(encoder_input)\n",
    "x = Dense(128, activation='relu')(x)\n",
    "dense = Model(encoder_input, x)\n",
    "\n",
    "input1 = Input((28,28))\n",
    "input2 = Input((28,28))\n",
    "\n",
    "# Calling the same `dense` model on both inputs reuses its weights.\n",
    "dense1 = dense(input1)\n",
    "dense2 = dense(input2)\n",
    "\n",
    "# Concatenate the two encodings and classify with a single sigmoid unit.\n",
    "merge_layer = Concatenate()([dense1, dense2])\n",
    "dense_layer = Dense(1, activation=\"sigmoid\")(merge_layer)\n",
    "model = Model(inputs=[input1, input2], outputs=dense_layer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_5 (InputLayer)            (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_6 (InputLayer)            (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "model_4 (Model)                 (None, 128)          100480      input_5[0][0]                    \n",
      "                                                                 input_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_3 (Concatenate)     (None, 256)          0           model_4[1][0]                    \n",
      "                                                                 model_4[2][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_7 (Dense)                 (None, 1)            257         concatenate_3[0][0]              \n",
      "==================================================================================================\n",
      "Total params: 100,737\n",
      "Trainable params: 100,737\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Binary cross-entropy loss, Adam optimizer, accuracy as the reported metric.\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W&B Run: https://app.wandb.ai/l2k2/siamese/runs/ah312s1c\n",
      "Call `%%wandb` in the cell containing your training loop to display live results.\n",
      "Epoch 1/10\n",
      "120000/120000 [==============================] - 37s 306us/step - loss: 0.6941 - acc: 0.4981\n",
      "Epoch 2/10\n",
      "120000/120000 [==============================] - 35s 288us/step - loss: 0.6933 - acc: 0.5082\n",
      "Epoch 3/10\n",
      " 25248/120000 [=====>........................] - ETA: 25s - loss: 0.6930 - acc: 0.5032"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-17-661971394d32>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"siamese\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpairs_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpairs_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mWandbCallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m   1037\u001b[0m                                         \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1038\u001b[0m                                         \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m                                         validation_steps=validation_steps)\n\u001b[0m\u001b[1;32m   1040\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1041\u001b[0m     def evaluate(self, x=None, y=None,\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[1;32m    197\u001b[0m                     \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m                 \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    200\u001b[0m                 \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    201\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   2713\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2716\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2717\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   2673\u001b[0m             \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2674\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m             \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2676\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2677\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1397\u001b[0m           ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m   1398\u001b[0m               \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1399\u001b[0;31m               run_metadata_ptr)\n\u001b[0m\u001b[1;32m   1400\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1401\u001b[0m           \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "wandb.init(project=\"siamese\")\n",
    "# Train on the pair images (first and second element of every pair) and\n",
    "# score the held-out test pairs after each epoch, so the logged metrics are\n",
    "# not limited to training accuracy.\n",
    "model.fit(\n",
    "    [pairs_train[:,0], pairs_train[:,1]], labels_train,\n",
    "    validation_data=([pairs_test[:,0], pairs_test[:,1]], labels_test),\n",
    "    batch_size=16, epochs=10,\n",
    "    callbacks=[WandbCallback()],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "\n",
    "def euclidean_distance(vects):\n",
    "    \"\"\"Return the Euclidean distance between two batches of vectors.\n",
    "\n",
    "    `vects` is a pair of backend tensors. The squared differences are summed\n",
    "    over axis 1 with keepdims=True, so one distance is produced per row.\n",
    "    \"\"\"\n",
    "    first, second = vects\n",
    "    squared_diff = K.square(first - second)\n",
    "    squared_dist = K.sum(squared_diff, axis=1, keepdims=True)\n",
    "    # Clamp to at least K.epsilon() before the square root (presumably to keep\n",
    "    # sqrt and its gradient well-behaved when the two vectors coincide).\n",
    "    clamped = K.maximum(squared_dist, K.epsilon())\n",
    "    return K.sqrt(clamped)\n",
    "\n",
    "\n",
    "# Weights shared: one Flatten + Dense(128) encoder is applied to both inputs.\n",
    "# Its input tensor is called `encoder_input` so that the built-in `input()`\n",
    "# is not shadowed in the kernel namespace.\n",
    "encoder_input = Input((28,28))\n",
    "x = Flatten()(encoder_input)\n",
    "x = Dense(128, activation='relu')(x)\n",
    "dense = Model(encoder_input, x)\n",
    "\n",
    "input1 = Input((28,28))\n",
    "input2 = Input((28,28))\n",
    "\n",
    "# Calling the same `dense` model on both inputs reuses its weights.\n",
    "dense1 = dense(input1)\n",
    "dense2 = dense(input2)\n",
    "\n",
    "# Reduce the two encodings to their Euclidean distance, then map that single\n",
    "# value to a match probability with one sigmoid unit.\n",
    "merge_layer = Lambda(euclidean_distance)([dense1,dense2])\n",
    "dense_layer = Dense(1, activation=\"sigmoid\")(merge_layer)\n",
    "model = Model(inputs=[input1, input2], outputs=dense_layer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_8 (InputLayer)            (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_9 (InputLayer)            (None, 28, 28)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "model_6 (Model)                 (None, 128)          100480      input_8[0][0]                    \n",
      "                                                                 input_9[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 1)            0           model_6[1][0]                    \n",
      "                                                                 model_6[2][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_9 (Dense)                 (None, 1)            2           lambda_1[0][0]                   \n",
      "==================================================================================================\n",
      "Total params: 100,482\n",
      "Trainable params: 100,482\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Binary cross-entropy loss, Adam optimizer, accuracy as the reported metric.\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W&B Run: https://app.wandb.ai/l2k2/siamese/runs/34u8rkzs\n",
      "Call `%%wandb` in the cell containing your training loop to display live results.\n",
      "Epoch 1/10\n",
      "120000/120000 [==============================] - 39s 324us/step - loss: 0.4544 - acc: 0.7934\n",
      "Epoch 2/10\n",
      "120000/120000 [==============================] - 40s 337us/step - loss: 0.3598 - acc: 0.8500\n",
      "Epoch 3/10\n",
      "120000/120000 [==============================] - 40s 336us/step - loss: 0.3391 - acc: 0.8585\n",
      "Epoch 4/10\n",
      "120000/120000 [==============================] - 41s 344us/step - loss: 0.3251 - acc: 0.8645\n",
      "Epoch 5/10\n",
      "120000/120000 [==============================] - 41s 340us/step - loss: 0.3147 - acc: 0.8700\n",
      "Epoch 6/10\n",
      "120000/120000 [==============================] - 39s 328us/step - loss: 0.3116 - acc: 0.8714\n",
      "Epoch 7/10\n",
      "120000/120000 [==============================] - 36s 303us/step - loss: 0.3099 - acc: 0.8720\n",
      "Epoch 8/10\n",
      "120000/120000 [==============================] - 39s 328us/step - loss: 0.3084 - acc: 0.8728\n",
      "Epoch 9/10\n",
      "120000/120000 [==============================] - 40s 333us/step - loss: 0.3073 - acc: 0.8734\n",
      "Epoch 10/10\n",
      "120000/120000 [==============================] - 41s 342us/step - loss: 0.3067 - acc: 0.8740\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f59f0da1c88>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.init(project=\"siamese\")\n",
    "# Train on the pair images (first and second element of every pair) and\n",
    "# score the held-out test pairs after each epoch, so the logged metrics are\n",
    "# not limited to training accuracy.\n",
    "model.fit(\n",
    "    [pairs_train[:,0], pairs_train[:,1]], labels_train,\n",
    "    validation_data=([pairs_test[:,0], pairs_test[:,1]], labels_test),\n",
    "    batch_size=16, epochs=10,\n",
    "    callbacks=[WandbCallback()],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5.0249963],\n",
       "       [5.2089543],\n",
       "       [3.201473 ],\n",
       "       [5.975113 ],\n",
       "       [4.750642 ],\n",
       "       [5.159156 ],\n",
       "       [3.20649  ],\n",
       "       [4.343628 ],\n",
       "       [4.144156 ],\n",
       "       [4.6518426],\n",
       "       [5.0975466],\n",
       "       [5.408521 ],\n",
       "       [1.4890893],\n",
       "       [4.7235627],\n",
       "       [6.125441 ],\n",
       "       [4.906605 ],\n",
       "       [5.3288717],\n",
       "       [5.8312774],\n",
       "       [3.9407134],\n",
       "       [3.7102926],\n",
       "       [3.0666695],\n",
       "       [4.858267 ],\n",
       "       [3.1148467],\n",
       "       [5.1245933],\n",
       "       [5.438567 ],\n",
       "       [6.089241 ],\n",
       "       [4.323873 ],\n",
       "       [5.2438707],\n",
       "       [3.6994154],\n",
       "       [8.483834 ],\n",
       "       [2.2479403],\n",
       "       [5.1236877],\n",
       "       [5.4368477],\n",
       "       [4.468987 ],\n",
       "       [5.3873725],\n",
       "       [6.7580404],\n",
       "       [3.3060052],\n",
       "       [5.1335177],\n",
       "       [4.319157 ],\n",
       "       [5.0302362],\n",
       "       [4.6010137],\n",
       "       [5.50213  ],\n",
       "       [4.0740275],\n",
       "       [5.5968795],\n",
       "       [3.6118877],\n",
       "       [6.381815 ],\n",
       "       [2.0798607],\n",
       "       [5.61024  ],\n",
       "       [4.9588966],\n",
       "       [5.2881026],\n",
       "       [4.3802567],\n",
       "       [5.4582043],\n",
       "       [3.44393  ],\n",
       "       [6.543419 ],\n",
       "       [4.0182133],\n",
       "       [6.253499 ],\n",
       "       [5.2035193],\n",
       "       [4.6923203],\n",
       "       [4.7950625],\n",
       "       [5.372182 ],\n",
       "       [5.810975 ],\n",
       "       [4.1489024],\n",
       "       [5.3010826],\n",
       "       [6.623857 ],\n",
       "       [5.3874826],\n",
       "       [5.415547 ],\n",
       "       [4.442915 ],\n",
       "       [4.873055 ],\n",
       "       [7.6372705],\n",
       "       [4.8424   ],\n",
       "       [5.0772877],\n",
       "       [3.8946657],\n",
       "       [5.032755 ],\n",
       "       [5.1405115],\n",
       "       [4.931571 ],\n",
       "       [7.632647 ],\n",
       "       [4.0060463],\n",
       "       [4.628199 ],\n",
       "       [4.618758 ],\n",
       "       [5.600623 ],\n",
       "       [2.2887163],\n",
       "       [3.9961503],\n",
       "       [3.0127811],\n",
       "       [5.9679093],\n",
       "       [4.1191463],\n",
       "       [4.0554   ],\n",
       "       [4.4896884],\n",
       "       [4.9444394],\n",
       "       [5.494559 ],\n",
       "       [6.010549 ],\n",
       "       [3.5746217],\n",
       "       [4.6270595],\n",
       "       [6.5535994],\n",
       "       [5.856349 ],\n",
       "       [4.723534 ],\n",
       "       [4.8173413],\n",
       "       [5.17779  ],\n",
       "       [4.6616287],\n",
       "       [3.3123584],\n",
       "       [4.677595 ]], dtype=float32)"
      ]
     },
     "execution_count": 206,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
